#!/usr/bin/python
# Copyright 2020 Makani Technologies LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Generate the message_stats.[ch] files."""

import collections
import os
import re
import sys
import textwrap

from makani.avionics.network import network_config


def _GetSenders(message_types):
  senders = collections.defaultdict(list)
  for message in message_types:
    for sender in message.all_senders:
      senders[sender].append(message)
  return senders


def _GetMaxMessagesPerSender(message_types):
  return max([len(x) for x in _GetSenders(message_types).itervalues()])


def _WriteMessageStatsHeader(script_name, file_without_extension, rel_path,
                             message_types):
  """Writes message_stats.h.

  The header declares the MessageFrequencyInfo / MessageSenderInfo types, the
  kMessageSenderInfo table and GetMessageFrequency(). It also emits a
  <NAME>_PERIOD_US / <NAME>_FREQUENCY_HZ #define pair for every message type
  whose frequency is positive.

  Args:
    script_name: This script's filename.
    file_without_extension: The full path to the output file, missing the '.h'.
    rel_path: The relative path from the autogenerated files root to
              file_without_extension; also used to derive the include guard.
    message_types: A list of message types.
  """
  periods = []
  for message in message_types:
    if message.frequency_hz > 0:
      macro_prefix = message.snake_name.upper()
      # Floor division keeps the integer period identical on Python 2 and 3.
      periods.append('#define %s_PERIOD_US %d' %
                     (macro_prefix, 1000000 // message.frequency_hz))
      periods.append('#define %s_FREQUENCY_HZ %d' %
                     (macro_prefix, message.frequency_hz))

  # e.g. 'a/b/message_stats' -> 'A_B_MESSAGE_STATS_H_'.
  header_guard = re.sub('[/.]', '_', rel_path.upper()) + '_H_'
  contents = textwrap.dedent("""
      #ifndef {guard}
      #define {guard}

      // Generated by {name}; do not edit.

      #include <stdint.h>

      #include "avionics/network/aio_node.h"
      #include "avionics/network/message_type.h"

      #ifdef __cplusplus
      extern "C" {{
      #endif

      typedef struct {{
        MessageType message;
        int32_t frequency_hz;
        int32_t period_us;
      }} MessageFrequencyInfo;

      #define MAX_MESSAGES_PER_SENDER {max_messages}
      typedef struct {{
        MessageFrequencyInfo frequency_info[MAX_MESSAGES_PER_SENDER];
        int32_t num_messages;
      }} MessageSenderInfo;

      extern const MessageSenderInfo kMessageSenderInfo[kNumAioNodes];

      int32_t GetMessageFrequency(AioNode sender, MessageType message);

      {periods}

      #ifdef __cplusplus
      }}  // extern "C"
      #endif

      #endif  // {guard}
      """[1:]).format(name=script_name, guard=header_guard,
                      periods='\n'.join(periods),
                      max_messages=_GetMaxMessagesPerSender(message_types))

  with open(file_without_extension + '.h', 'w') as f:
    f.write(contents)


def _WriteMessageStatsSource(script_name, file_without_extension, rel_path,
                             message_types):
  """Writes message_stats.c for the non-winch messages.

  The generated file defines kMessageSenderInfo, a table indexed by sender
  node that lists each message type the node sends with its frequency and
  period. It also defines GetMessageFrequency(), which looks up a
  (sender, message) pair in that table and returns 0 when it is absent.

  Args:
    script_name: This script's filename.
    file_without_extension: The full path to the output file, missing the '.c'.
    rel_path: The relative path from the autogenerated files root to
              file_without_extension; used to #include the matching header.
    message_types: A list of message types.
  """
  parts = [textwrap.dedent("""
      // Generated by {name}; do not edit.

      #include <assert.h>

      #include "{header}"

      const MessageSenderInfo kMessageSenderInfo[kNumAioNodes] = {{"""[1:]
                          ).format(name=script_name, header=rel_path + '.h')]

  senders = _GetSenders(message_types)
  # Emit senders in a stable order so regenerating the file yields identical
  # output. The designated initializers ([node] = ...) make the order
  # irrelevant to the C compiler.
  for sender in sorted(senders, key=lambda s: s.enum_name):
    messages = senders[sender]
    parts.append('  [%s] = {{' % sender.enum_name)
    for message in messages:
      frequency = message.frequency_hz
      # Non-positive frequencies get a sentinel period of -1 instead of
      # dividing by zero; floor division keeps Python 2 and 3 output equal.
      period = 1000000 // frequency if frequency > 0 else -1
      parts.append('    {%s, %d, %d},' % (message.enum_name, frequency, period))
    parts.append('    },')
    parts.append('    %d},' % len(messages))
  parts.append('};\n')

  # The lookup reads the table through a const pointer rather than copying the
  # whole MessageSenderInfo struct on every call.
  parts.append(textwrap.dedent("""
      int32_t GetMessageFrequency(AioNode sender, MessageType message) {
        assert(IsValidNode(sender) || IsUnknownNode(sender));
        const MessageSenderInfo *info = &kMessageSenderInfo[sender];
        for (int32_t i = 0; i < info->num_messages; ++i) {
          if (info->frequency_info[i].message == message) {
            return info->frequency_info[i].frequency_hz;
          }
        }
        return 0;
      }
      """))

  with open(file_without_extension + '.c', 'w') as f:
    f.write('\n'.join(parts))


def _WriteMessageStats(autogen_root, output_dir, script_name, message_types):
  """Generates the message_stats.c / message_stats.h pair for non-winch messages.

  Args:
    autogen_root: The MAKANI_HOME-equivalent top directory; the relative path
        handed to the writers is computed against it.
    output_dir: Directory that receives both generated files.
    script_name: This script's filename.
    message_types: A list of message types.
  """
  base_path = os.path.join(output_dir, 'message_stats')
  writer_args = (script_name, base_path,
                 os.path.relpath(base_path, autogen_root), message_types)
  # The .c file is written before the .h file.
  for writer in (_WriteMessageStatsSource, _WriteMessageStatsHeader):
    writer(*writer_args)


def main(argv):
  """Entry point: loads the network config and writes message_stats.[ch].

  Args:
    argv: Command-line arguments, parsed by network_config.ParseGenerationFlags;
        the first remaining argument names this script.
  """
  flags, remaining_argv = network_config.ParseGenerationFlags(argv)
  network = network_config.NetworkConfig(flags.network_file)
  _WriteMessageStats(flags.autogen_root, flags.output_dir,
                     os.path.basename(remaining_argv[0]),
                     network.aio_messages)


if __name__ == '__main__':
  main(sys.argv)
